Support gqa in aten spda #2408
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Codecov Report ❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2408      +/-   ##
==========================================
- Coverage   70.38%   70.24%   -0.14%
==========================================
  Files         199      199
  Lines       25223    25270      +47
  Branches     2686     2693       +7
==========================================
- Hits        17753    17751       -2
- Misses       6541     6586      +45
- Partials      929      933       +4

☔ View full report in Codecov by Sentry.
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
return _aten_scaled_dot_product_attention_bool_mask_onnx(
    query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
)
Check failure
Code scanning / CodeQL
Wrong name for an argument in a call
Copilot Autofix (AI, 3 months ago)
To fix the issue, the keyword argument enable_gqa should be removed from the call to _aten_scaled_dot_product_attention_bool_mask_onnx on line 1994. This ensures that the function is called with only the parameters it supports. Removing enable_gqa will not affect the functionality of _aten_scaled_dot_product_attention_bool_mask_onnx, as it does not use this argument.
@@ -1994,3 +1994,3 @@
 return _aten_scaled_dot_product_attention_bool_mask_onnx(
-    query, key, value, attn_mask, scale, dropout_p, enable_gqa=enable_gqa
+    query, key, value, attn_mask, scale, dropout_p
 )
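For illustration only, here is a minimal sketch of the pattern CodeQL flags as "Wrong name for an argument in a call" (hypothetical names, not the torchlib code): Python rejects a keyword argument the callee does not declare, so the fix is either to drop the argument, as the autofix does, or to add enable_gqa to the helper's signature.

def _bool_mask_helper(query, key, value, attn_mask, scale, dropout_p):
    # Stand-in for _aten_scaled_dot_product_attention_bool_mask_onnx; body elided.
    return None

try:
    # The callee declares no enable_gqa parameter, so the keyword is rejected.
    _bool_mask_helper(None, None, None, None, 1.0, 0.0, enable_gqa=True)
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'enable_gqa'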
    axis=0
)
value_unsqueezed = op.Unsqueeze(value, [-2])
value_tiled = op.Tile(value_unsqueezed, op.Concat(
op.Tile does not align with the PyTorch implementation.
if (
(q_num_heads != k_num_heads)
and (q_num_heads % k_num_heads == 0)
and (k_num_heads == v_num_heads)
):
seq_reps = q_num_heads // k_num_heads
# Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
K = np.repeat(K, repeats=seq_reps, axis=1)
V = np.repeat(V, repeats=seq_reps, axis=1)
We should be able to reuse repeat_interleave here when it's done.
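To illustrate the ordering concern (a plain numpy sketch, not the torchlib code): tiling cycles through the KV heads, while PyTorch's GQA expansion repeats each head in place, so the two pair query heads with different KV heads.

import numpy as np

k = np.arange(4).reshape(1, 2, 1, 2)  # (batch, kv_heads=2, seq=1, head_dim=2)
reps = 2                              # q_num_heads // k_num_heads

interleaved = np.repeat(k, reps, axis=1)   # head order: [h0, h0, h1, h1]
cycled = np.tile(k, (1, reps, 1, 1))       # head order: [h0, h1, h0, h1]

print(interleaved[0, :, 0, 0])  # [0 0 2 2]
print(cycled[0, :, 0, 0])       # [0 2 0 2]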
Should we use Expand for the repeat_interleave instead of Tile, for simplicity?
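For reference, a sketch of the Expand-based interleave in plain numpy (assumed shapes, not the torchlib source): insert a new axis after the head axis, broadcast it to the group size, then merge it back into the head axis, which matches np.repeat along the head axis.

import numpy as np

def repeat_interleave_heads(x, reps):
    # Unsqueeze -> Expand -> Reshape: the broadcast-based equivalent of
    # repeating each KV head `reps` times in place.
    b, h, s, d = x.shape
    expanded = np.broadcast_to(x[:, :, None, :, :], (b, h, reps, s, d))
    return expanded.reshape(b, h * reps, s, d)

k = np.arange(4).reshape(1, 2, 1, 2)
assert np.array_equal(repeat_interleave_heads(k, 2), np.repeat(k, 2, axis=1))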
I wonder if we can just adapt whatever function body is in defs.cc to torchlib? Is there any difference?
Probably not. I must have been using the old implementation.
Fix pytorch/pytorch#151762